Dev/qwen omni #47
Conversation
Can probably refactor these changes into a new Qwen Omni dataset that inherits from this one and overrides only what differs, keeping vision_audio_dataset unchanged.
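A minimal sketch of that suggestion, with a stub standing in for the repo's real vision_audio_dataset class (the class and method names here are illustrative, not this repo's API):

import copy


class VisionAudioDataset:  # stub standing in for the repo's real dataset class
    def process_sample(self, sample):
        return copy.deepcopy(sample)


class QwenOmniDataset(VisionAudioDataset):
    """Hypothetical Qwen Omni dataset: inherit and override only what differs."""

    def process_sample(self, sample):
        sample = super().process_sample(sample)
        # add the Omni-specific handling here (e.g. audio token accounting),
        # leaving the base vision/audio dataset untouched
        return sample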
Can override this logic in qwen2_5_omni_processor and keep this file unchanged.
src/lmms_engine/mapping_func.py
Outdated
# for Qwen2.5-Omni, use ThinkerForConditionalGeneration
if config.model_type == "qwen2_5_omni":
    from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import (
        Qwen2_5OmniThinkerForConditionalGeneration,
    )

    class Qwen2_5OmniThinkerForConditionalGenerationWithDtype(
        Qwen2_5OmniThinkerForConditionalGeneration
    ):
        @classmethod
        def from_pretrained(cls, *args, **kwargs):
            if "torch_dtype" not in kwargs:
                kwargs["torch_dtype"] = "auto"
            model = super().from_pretrained(*args, **kwargs)
            import torch

            model = model.to(dtype=torch.bfloat16)
            return model

    return Qwen2_5OmniThinkerForConditionalGenerationWithDtype
Qwen2_5OmniThinkerForConditionalGeneration is itself a PreTrainedModel and can use from_pretrained directly, so this part is redundant. Register it as an auto causal LM and remove this wrapper.
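A minimal sketch of that registration, assuming the thinker config is exposed as Qwen2_5OmniThinkerConfig and a transformers version whose register() accepts exist_ok (module paths may differ across versions):

from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import (
    Qwen2_5OmniThinkerConfig,
)
from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import (
    Qwen2_5OmniThinkerForConditionalGeneration,
)

# Register the thinker config and class in the auto mappings so that
# AutoModelForCausalLM.from_pretrained(..., torch_dtype="auto") resolves the
# thinker directly; exist_ok=True tolerates re-registration across runs.
AutoConfig.register(
    Qwen2_5OmniThinkerConfig.model_type, Qwen2_5OmniThinkerConfig, exist_ok=True
)
AutoModelForCausalLM.register(
    Qwen2_5OmniThinkerConfig,
    Qwen2_5OmniThinkerForConditionalGeneration,
    exist_ok=True,
)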
if rms_norm:
    modeling_qwen2_5_omni.Qwen2RMSNorm = LigerRMSNorm
if cross_entropy:
    modeling_qwen2_5_omni.CrossEntropyLoss = LigerCrossEntropyLoss
if fused_linear_cross_entropy:
    modeling_qwen2_5_omni.Qwen2_5OmniThinkerForConditionalGeneration.forward = (
        qwen2_5_omni_lce_forward
    )
if swiglu:
    modeling_qwen2_5_omni.Qwen2MLP = LigerSwiGLUMLP
Qwen2.5 Omni doesn't have Qwen2RMSNorm or Qwen2MLP, so this won't patch the correct modules. Check https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py for the correct names.
Correct me again if I'm wrong, but Qwen2.5 Omni does use
Qwen2RMSNorm: https://github.com/huggingface/transformers/blob/0419ff881d7bb503f4fc0f0a7a5aac3d012c9b91/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py#L46C36-L46C48
and also Qwen2MLP:
https://github.com/huggingface/transformers/blob/0419ff881d7bb503f4fc0f0a7a5aac3d012c9b91/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py#L1401
I think I only scanned up to this: https://github.com/huggingface/transformers/blob/0419ff881d7bb503f4fc0f0a7a5aac3d012c9b91/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py#L976-L987. Not sure whether that part can also be patched, but it should be fine to patch these two via Qwen2RMSNorm and Qwen2MLP. Should be fine for now; no need to change this part.
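A quick, hedged way to double-check which norm/MLP classes the Omni modeling module actually exposes before monkey-patching (class names can shift between transformers versions):

import inspect

from transformers.models.qwen2_5_omni import modeling_qwen2_5_omni

# List classes whose names mention RMSNorm or MLP; this confirms whether
# Qwen2RMSNorm / Qwen2MLP are the attributes the Liger patches should target.
candidates = [
    name
    for name, obj in inspect.getmembers(modeling_qwen2_5_omni, inspect.isclass)
    if "RMSNorm" in name or "MLP" in name
]
print(candidates)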
This doesn't actually remove padding from the ops right now. The input ids will be 2D, and I suspect you are directly reusing Qwen2.5 VL's forward? This could be buggy.
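For illustration only, a minimal sketch of what "removing padding" usually means here: flattening a padded 2D batch into a packed 1D sequence plus cumulative lengths. The helper name and exact shapes are assumptions, not this repo's API:

import torch


def unpad_input_ids(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    """Flatten a padded (batch, seq_len) batch into a packed 1D sequence.

    Returns the packed ids and cu_seqlens, the layout that varlen /
    flash-attention style forwards typically expect instead of padded 2D ids.
    """
    seqlens = attention_mask.sum(dim=1)
    packed_ids = input_ids[attention_mask.bool()]  # row-major, padding dropped
    cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32)
    return packed_ids, cu_seqlens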
src/lmms_engine/models/utils.py
Outdated
elif config.model_type == "qwen2_5_omni":
    self.config = (
        config.text_config if hasattr(config, "text_config") else config
    )
If the config is already reset here, there is actually no need to rewrite _estimate_qwen2_flops later.
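A rough sketch of the point, assuming a standard Qwen2-style text config (the field names are the usual ones; GQA and attention-score FLOPs are ignored for simplicity): once self.config points at text_config, a generic decoder FLOPs estimate can read it directly, with no Omni-specific override.

def estimate_decoder_flops_per_token(cfg) -> float:
    # Rough per-token forward FLOPs: 2 FLOPs per weight per matmul,
    # counting attention projections, MLP, and the lm_head.
    attn = 2 * 4 * cfg.hidden_size * cfg.hidden_size         # q, k, v, o projections
    mlp = 2 * 3 * cfg.hidden_size * cfg.intermediate_size    # gate, up, down projections
    head = 2 * cfg.hidden_size * cfg.vocab_size              # lm_head
    return cfg.num_hidden_layers * (attn + mlp) + head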
src/lmms_engine/mapping_func.py
Outdated
# for Qwen2.5-Omni, load only the thinker model (not the full model with the
# talker); this needs special handling, otherwise the full model will be loaded
if hasattr(config, "model_type") and config.model_type == "qwen2_5_omni":

    class Qwen2_5OmniThinkerLoader:
        @staticmethod
        def from_pretrained(pretrained_model_name_or_path, *args, **kwargs):
            # Load the full model first
            full_model = AutoModelForCausalLM.from_pretrained(
                pretrained_model_name_or_path, *args, **kwargs
            )
            # Extract and return only the thinker
            thinker_model = full_model.thinker
            # Clean up the full model to save memory
            if hasattr(full_model, "talker"):
                del full_model.talker
            if hasattr(full_model, "token2wav"):
                del full_model.token2wav
            del full_model
            return thinker_model

    return Qwen2_5OmniThinkerLoader
This part should be removed
Hi, I think most of the parts look fine right now. A few suggestions that could improve this PR further: I see that we are only training the Thinker part, so I think you can wrap the actual training model into a PreTrainedModel and save it as a pretrained checkpoint. Then you can simply use from_pretrained with that and rename the model to qwen2_5_omni_thinker (you can also push the converted checkpoint to the hub, maybe under lmms-lab).
To clarify further, my feeling is that you can actually drop the whole talker module. You can do it as follows: the thinker module is itself a PreTrainedModel and has its own model type, so you can extract it, save it as a standalone pretrained checkpoint, and load it directly with from_pretrained.
You might need to change your patching logic accordingly though, because the model structure changes a bit.
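A minimal sketch of that extraction flow, assuming a transformers version that ships Qwen2_5OmniForConditionalGeneration (the output directory name is illustrative):

import torch
from transformers import Qwen2_5OmniForConditionalGeneration

full = Qwen2_5OmniForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.bfloat16
)
# The thinker is itself a PreTrainedModel with its own config and model type,
# so it can be saved as a standalone checkpoint and loaded directly afterwards.
full.thinker.save_pretrained("Qwen2.5-Omni-Thinker-7B")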
For the SP all-to-all ops, this could possibly be caused by a tensor shape mismatch during the gather operations, so it gets stuck there. You may need to check the split size on each rank and validate whether the gather size is being set correctly.
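One hedged way to debug that: before the collective, gather the per-rank shard sizes and assert they match what a fixed-size gather/all-to-all expects. The function name and the choice of sequence dimension here are illustrative:

import torch
import torch.distributed as dist


def check_sp_shard_sizes(local: torch.Tensor, sp_group=None, dim: int = 1) -> None:
    """Raise early if SP ranks hold unevenly split sequence shards."""
    world_size = dist.get_world_size(sp_group)
    local_size = torch.tensor([local.shape[dim]], device=local.device)
    sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
    dist.all_gather(sizes, local_size, group=sp_group)
    sizes = [int(s.item()) for s in sizes]
    if len(set(sizes)) != 1:
        raise RuntimeError(
            f"Uneven SP shards {sizes}: a fixed-size all_gather/all_to_all will hang or corrupt data"
        )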
Removed the comment about using exist_ok=True in the AutoConfig registration.
I have extracted the thinker only and pushed it here: "ngqtrung/Qwen2.5-Omni-Thinker-7B". When I tried to move it to the lmms-lab hub, it showed "The storage patterns of org lmms-lab tripped our internal systems! Please contact us at [email protected] so we can verify your account and unlock more storage for your use-case." So for now I'm keeping it in my own hub.
Hi Trung, thank you for all the efforts. I think most of it looks nice, and LGTM. Can you resolve the conflict and change the example config into the new format? Since we are using Hydra now, we can either use a config YAML or simply override args on the CLI. Check examples/qwen3_vl for the details. Thanks!
Please don't use merge, because it merges every commit into the main branch and causes a non-linear history. Instead, use squash and merge. I will revert this PR for now and re-open a PR for this feature.
Oh, actually on my testing envs this also breaks some of my CI/CD in some cases. Let me see if we can fix it together in a new PR.
Sure, I can help with the fix. What is the current issue?
It's okay, I've already fixed the import error.
Motivation
Support Qwen 2.5 Omni
Modifications
Checklist